// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_VECTORUNIT_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_VECTORUNIT_H_

#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"
#include "src/LayerNormUnit.h"
#include "src/OutputAddressGenerator.h"
#include "src/SoftmaxUnit.h"
#include "src/VectorFetchUnit.h"

// Element-wise scale / ReLU stage operating on a stream of vectors.
//
// For every VectorParams command popped from `paramsIn`, the unit computes
// tensorSize as the product of the six entries params.loops[0..1][0..2], then
// for each of tensorSize vectors popped from `vectorIn` it:
//   * multiplies every element by params.SCALE_FACTOR if params.SCALE is set,
//   * replaces every negative element with 0 if params.RELU is set,
// and pushes the result to `vectorOut`.
//
// Template parameters:
//   DTYPE - element type of the vectors.
//   WIDTH - not referenced inside this module; retained so that existing
//           instantiations (e.g. ScaleUnit<DTYPE, WIDTH, NROWS>) keep
//           compiling.
//   NROWS - number of elements in each vector.
template <typename DTYPE, int WIDTH, int NROWS>
SC_MODULE(ScaleUnit) {
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);

  Connections::In<VectorParams> CCS_INIT_S1(paramsIn);
  // The data ports carry NROWS-element vectors. They were previously declared
  // with WIDTH, which disagreed with the Pack1D<DTYPE, NROWS> value that run()
  // pops and pushes, so the module was only consistent when WIDTH == NROWS.
  Connections::In<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(vectorIn);
  Connections::Out<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(vectorOut);

  SC_CTOR(ScaleUnit) {
    // Single clocked thread, reset asynchronously while rstn is low.
    SC_THREAD(run);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Main thread: one command in, tensorSize vectors processed and forwarded.
  void run() {
    paramsIn.Reset();
    vectorIn.Reset();
    vectorOut.Reset();

    wait();

    while (true) {
      VectorParams params = paramsIn.Pop();

      // Number of vectors to process for this command.
      int tensorSize = 1;
#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 3; j++) {
          tensorSize *= params.loops[i][j];
        }
      }

#pragma hls_pipeline_init_interval 1
      for (int n = 0; n < tensorSize; n++) {
        Pack1D<DTYPE, NROWS> vec = vectorIn.Pop();

        // Distinct lane indices below avoid shadowing the streaming index.
        if (params.SCALE) {
#pragma hls_unroll yes
          for (int lane = 0; lane < NROWS; lane++) {
            vec[lane] *= params.SCALE_FACTOR;
          }
        }

        if (params.RELU) {
#pragma hls_unroll yes
          for (int lane = 0; lane < NROWS; lane++) {
            if (vec[lane] < 0) {
              vec[lane] = 0;
            }
          }
        }

        vectorOut.Push(vec);
      }
    }
  }
};

// Top-level vector post-processing unit.
//
// Each VectorParams command received on `paramsIn` is dispatched to exactly
// one processing unit, chosen in this priority order:
//   params.SOFTMAX    -> softmaxUnit,   fed from `vectorFetchDataResponse`
//   params.LAYER_NORM -> layerNormUnit, fed from `vectorFetchDataResponse`
//   otherwise         -> scaleUnit,     fed from `systolicArrayOutput`
// The outputs of all three units are wired into outputAddressGenerator, which
// drives `outputAddress`, `vectorUnitOutput` and `done`.
//
// Template parameters:
//   DTYPE - element type of the vectors.
//   WIDTH - only forwarded to ScaleUnit.
//   NROWS - number of elements in each vector.
template <typename DTYPE, int WIDTH, int NROWS>
SC_MODULE(VectorUnit) {
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);

  // External interface.
  Connections::In<VectorParams> CCS_INIT_S1(paramsIn);
  Connections::In<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(systolicArrayOutput);
  Connections::Out<int> CCS_INIT_S1(vectorFetchAddressRequest);
  Connections::In<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(vectorFetchDataResponse);
  Connections::Out<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(vectorUnitOutput);
  Connections::Out<int> CCS_INIT_S1(outputAddress);
  Connections::SyncOut CCS_INIT_S1(done);

  // Sub-unit driving the vectorFetchAddressRequest port.
  FetchUnit<NROWS> CCS_INIT_S1(fetchUnit);
  Connections::Combinational<VectorParams> CCS_INIT_S1(fetchUnitParams);

  SoftmaxUnit<DTYPE, NROWS> CCS_INIT_S1(softmaxUnit);
  Connections::Combinational<VectorParams> CCS_INIT_S1(softmaxUnitParams);
  Connections::Combinational<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(
      softmaxUnitInput);
  Connections::Combinational<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(
      softmaxUnitOutput);

  LayerNormUnit<DTYPE, NROWS> CCS_INIT_S1(layerNormUnit);
  Connections::Combinational<VectorParams> CCS_INIT_S1(layerNormUnitParams);
  Connections::Combinational<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(
      layerNormUnitInput);
  Connections::Combinational<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(
      layerNormUnitOutput);

  ScaleUnit<DTYPE, WIDTH, NROWS> CCS_INIT_S1(scaleUnit);
  Connections::Combinational<VectorParams> CCS_INIT_S1(scaleUnitParams);
  Connections::Combinational<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(scaleUnitInput);
  Connections::Combinational<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(
      scaleUnitOutput);

  OutputAddressGenerator<DTYPE, NROWS, NROWS> CCS_INIT_S1(
      outputAddressGenerator);
  Connections::Combinational<VectorParams> CCS_INIT_S1(outputAddressGenParams);

  // Carries each command from read_params() to connect_inputs().
  Connections::Combinational<VectorParams> CCS_INIT_S1(inputConnectionParams);

  SC_CTOR(VectorUnit) {
    fetchUnit.clk(clk);
    fetchUnit.rstn(rstn);
    fetchUnit.paramsIn(fetchUnitParams);
    fetchUnit.vectorFetchAddressRequest(vectorFetchAddressRequest);

    scaleUnit.clk(clk);
    scaleUnit.rstn(rstn);
    scaleUnit.paramsIn(scaleUnitParams);
    scaleUnit.vectorIn(scaleUnitInput);
    scaleUnit.vectorOut(scaleUnitOutput);

    softmaxUnit.clk(clk);
    softmaxUnit.rstn(rstn);
    softmaxUnit.paramsIn(softmaxUnitParams);
    softmaxUnit.vectorIn(softmaxUnitInput);
    softmaxUnit.vectorOut(softmaxUnitOutput);

    layerNormUnit.clk(clk);
    layerNormUnit.rstn(rstn);
    layerNormUnit.paramsIn(layerNormUnitParams);
    layerNormUnit.vectorIn(layerNormUnitInput);
    layerNormUnit.vectorOut(layerNormUnitOutput);

    outputAddressGenerator.clk(clk);
    outputAddressGenerator.rstn(rstn);
    outputAddressGenerator.paramsIn(outputAddressGenParams);
    outputAddressGenerator.scaleUnitOutput(scaleUnitOutput);
    outputAddressGenerator.softmaxUnitOutput(softmaxUnitOutput);
    outputAddressGenerator.layerNormUnitOutput(layerNormUnitOutput);
    outputAddressGenerator.outputAddress(outputAddress);
    outputAddressGenerator.outputData(vectorUnitOutput);
    outputAddressGenerator.done(done);

    // Both threads are clocked and reset asynchronously while rstn is low.
    SC_THREAD(read_params);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(connect_inputs);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Pops commands from `paramsIn` and fans each one out to the sub-units that
  // take part in executing it.
  void read_params() {
    paramsIn.Reset();
    fetchUnitParams.ResetWrite();
    scaleUnitParams.ResetWrite();
    softmaxUnitParams.ResetWrite();
    layerNormUnitParams.ResetWrite();
    outputAddressGenParams.ResetWrite();
    inputConnectionParams.ResetWrite();

    wait();

    while (true) {
      VectorParams params = paramsIn.Pop();

      // The softmax and layer-norm paths consume vectorFetchDataResponse in
      // connect_inputs(), so fetchUnit has to receive the command in those
      // modes. Previously fetchUnitParams was reset but never written, leaving
      // fetchUnit without any command.
      // NOTE(review): assumes FetchUnit should only be driven for
      // SOFTMAX / LAYER_NORM commands - confirm against FetchUnit.
      if (params.SOFTMAX || params.LAYER_NORM) {
        fetchUnitParams.Push(params);
      }

      // Exactly one processing unit handles the command.
      if (params.SOFTMAX) {
        softmaxUnitParams.Push(params);
      } else if (params.LAYER_NORM) {
        layerNormUnitParams.Push(params);
      } else {
        scaleUnitParams.Push(params);
      }

      inputConnectionParams.Push(params);
      outputAddressGenParams.Push(params);
    }
  }

  // Routes tensorSize input vectors per command to the selected unit, where
  // tensorSize is the product of the six entries params.loops[0..1][0..2].
  void connect_inputs() {
    inputConnectionParams.ResetRead();
    systolicArrayOutput.Reset();
    vectorFetchDataResponse.Reset();

    scaleUnitInput.ResetWrite();
    softmaxUnitInput.ResetWrite();
    layerNormUnitInput.ResetWrite();

    wait();

    while (true) {
      VectorParams params = inputConnectionParams.Pop();

      // Number of vectors to forward for this command.
      int tensorSize = 1;
#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 3; j++) {
          tensorSize *= params.loops[i][j];
        }
      }

      // Same selection priority as read_params(): softmax, then layer norm,
      // then the scale unit.
#pragma hls_pipeline_init_interval 1
      for (int i = 0; i < tensorSize; i++) {
        if (params.SOFTMAX) {
          softmaxUnitInput.Push(vectorFetchDataResponse.Pop());
        } else if (params.LAYER_NORM) {
          layerNormUnitInput.Push(vectorFetchDataResponse.Pop());
        } else {
          scaleUnitInput.Push(systolicArrayOutput.Pop());
        }
      }
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_VECTORUNIT_H_
